{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Choosing an Index"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Given the large number of indexes and quantizers available, how do you choose one for your experiment or application? In this part, we give general suggestions on how to choose the one that fits your needs."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 0. Preparation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Packages"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For CPU usage, run:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip install -U faiss-cpu numpy h5py"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For GPU usage on a Linux x86_64 system, install with Conda:\n",
    "\n",
    "```conda install -c pytorch -c nvidia faiss-gpu=1.8.0```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from urllib.request import urlretrieve\n",
    "import h5py\n",
    "import faiss\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial, we'll use [SIFT1M](http://corpus-texmex.irisa.fr/), a very popular dataset for ANN evaluation, to compare the indexes.\n",
    "\n",
    "Run the following cell to download the dataset, or download it manually from the [ann-benchmarks](https://github.com/erikbern/ann-benchmarks?tab=readme-ov-file#data-sets) repo."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_url = \"http://ann-benchmarks.com/sift-128-euclidean.hdf5\"\n",
    "destination = \"data.hdf5\"\n",
    "urlretrieve(data_url, destination)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then load the data from the HDF5 file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1000000, 128) float32\n",
      "(10000, 128) float32\n"
     ]
    }
   ],
   "source": [
    "with h5py.File('data.hdf5', 'r') as f:\n",
    "    corpus = f['train'][:]\n",
    "    query = f['test'][:]\n",
    "\n",
    "print(corpus.shape, corpus.dtype)\n",
    "print(query.shape, query.dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "d = corpus[0].shape[0]\n",
    "k = 100"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Helper function"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following is a helper function for computing recall."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# compute recall from the prediction results and ground truth\n",
    "def compute_recall(res, truth):\n",
    "    recall = 0\n",
    "    for i in range(len(res)):\n",
    "        intersect = np.intersect1d(res[i], truth[i])\n",
    "        recall += len(intersect) / len(res[i])\n",
    "    recall /= len(res)\n",
    "\n",
    "    return recall"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Flat Index"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A Flat index uses brute force to search the neighbors of each query. It guarantees the optimal result with 100% recall, so we use its results as the ground truth."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 69.2 ms, sys: 80.6 ms, total: 150 ms\n",
      "Wall time: 149 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "index = faiss.IndexFlatL2(d)\n",
    "index.add(corpus)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 17min 30s, sys: 1.62 s, total: 17min 31s\n",
      "Wall time: 2min 1s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "D, I_truth = index.search(query, k)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. IVF Index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 10.6 s, sys: 831 ms, total: 11.4 s\n",
      "Wall time: 419 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "nlist = 5\n",
    "nprob = 3\n",
    "\n",
    "quantizer = faiss.IndexFlatL2(d)\n",
    "index = faiss.IndexIVFFlat(quantizer, d, nlist)\n",
    "index.nprobe = nprob\n",
    "\n",
    "index.train(corpus)\n",
    "index.add(corpus)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 9min 15s, sys: 598 ms, total: 9min 16s\n",
      "Wall time: 12.5 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "D, I = index.search(query, k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Recall: 0.9999189999999997\n"
     ]
    }
   ],
   "source": [
    "recall = compute_recall(I, I_truth)\n",
    "print(f\"Recall: {recall}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "From the test we can see that `IndexIVFFlat` with L2 distance greatly improves search speed (12.5 s versus about 2 min of wall time for the Flat index) with only a tiny loss of recall."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. HNSW Index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 11min 21s, sys: 595 ms, total: 11min 22s\n",
      "Wall time: 17 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "M = 64\n",
    "ef_search = 32\n",
    "ef_construction = 64\n",
    "\n",
    "index = faiss.IndexHNSWFlat(d, M)\n",
    "# set the two parameters before adding data\n",
    "index.hnsw.efConstruction = ef_construction\n",
    "index.hnsw.efSearch = ef_search\n",
    "\n",
    "index.add(corpus)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 5.14 s, sys: 3.94 ms, total: 5.14 s\n",
      "Wall time: 110 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "D, I = index.search(query, k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Recall: 0.8963409999999716\n"
     ]
    }
   ],
   "source": [
    "recall = compute_recall(I, I_truth)\n",
    "print(f\"Recall: {recall}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With a wall-clock search time of about 110 ms, we can see why HNSW is one of the best choices when search speed is the priority. The drop in recall to about 89.6% is acceptable. However, the longer index construction time (17 s here) and the large memory footprint need to be considered."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. LSH"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 13.7 s, sys: 660 ms, total: 14.4 s\n",
      "Wall time: 12.1 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "nbits = d * 8\n",
    "\n",
    "index = faiss.IndexLSH(d, nbits)\n",
    "index.train(corpus)\n",
    "index.add(corpus)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 3min 20s, sys: 84.2 ms, total: 3min 20s\n",
      "Wall time: 5.64 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "D, I = index.search(query, k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Recall: 0.5856720000000037\n"
     ]
    }
   ],
   "source": [
    "recall = compute_recall(I, I_truth)\n",
    "print(f\"Recall: {recall}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we covered in the last notebook, LSH is not a good choice when the data dimension is large, and a dimension of 128 is already a burden for it. Even with a relatively small `nbits` of d * 8, index creation and search still take a fairly long time, and the recall of about 58.6% is not satisfactory."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Scalar Quantizer Index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 550 ms, sys: 18 ms, total: 568 ms\n",
      "Wall time: 87.4 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "qtype = faiss.ScalarQuantizer.QT_8bit\n",
    "metric = faiss.METRIC_L2\n",
    "\n",
    "index = faiss.IndexScalarQuantizer(d, qtype, metric)\n",
    "index.train(corpus)\n",
    "index.add(corpus)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 7min 36s, sys: 169 ms, total: 7min 36s\n",
      "Wall time: 12.7 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "D, I = index.search(query, k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Recall: 0.990444999999872\n"
     ]
    }
   ],
   "source": [
    "recall = compute_recall(I, I_truth)\n",
    "print(f\"Recall: {recall}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here the scalar quantizer index's performance looks very similar to that of the Flat index. This is because the vector elements in the SIFT dataset are integers in the range [0, 218], so the index does not lose too much information during 8-bit scalar quantization. For datasets with a more complex float32 distribution, the difference will be more obvious."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Product Quantizer Index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 46.7 s, sys: 22.3 ms, total: 46.7 s\n",
      "Wall time: 1.36 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "M = 16\n",
    "nbits = 8\n",
    "metric = faiss.METRIC_L2\n",
    "\n",
    "index = faiss.IndexPQ(d, M, nbits, metric)\n",
    "\n",
    "index.train(corpus)\n",
    "index.add(corpus)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 1min 37s, sys: 106 ms, total: 1min 37s\n",
      "Wall time: 2.8 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "D, I = index.search(query, k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Recall: 0.630898999999999\n"
     ]
    }
   ],
   "source": [
    "recall = compute_recall(I, I_truth)\n",
    "print(f\"Recall: {recall}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The product quantizer index does not stand out in any single aspect, but it balances the tradeoffs to some extent. It is widely used in real applications in combination with other indexes such as IVF or HNSW."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
